"""
Post-process Mindboggle-101 volume images for distribution,
using Mindboggle, FreeSurfer, and FSL tools.

This is modified from the original code_postprocess_Mindboggle101.py to
(1) only generate DKT31 (not DKT25) labeling protocol data
(2) use the existing T1 data and transforms (don't generate new ones):

  - Convert label volume from FreeSurfer to original space
  x Extract brain by masking with manual cortical and automated subcortical labels
    (skipped in this version: the existing t1weighted_brain.nii.gz is used instead)
  - Remove non-DKT31 (non-cortical) labels
  - Affine register T1-weighted brain to MNI152 brain
  - Transfer whole-head images with affine transform
  - Transfer labeled images with affine transform (nearest-neighbor interpolation)

Authors:  Arno Klein  .  arno@mindboggle.info  .  www.binarybottle.com

(c) 2013-2019  Mindbogglers (www.mindboggle.info), under Apache License Version 2.0

"""

import os
import subprocess


# Paths and registration template
mb101_path = os.path.join('/Users', 'arno.klein', 'Data', 'Mindboggle101')
mb_info_path = os.path.join(mb101_path, 'docs')
template = os.path.join(mb101_path, 'MNI152_T1_1mm_brain.nii.gz')

# Read the subject list: one subject name per line, with all whitespace
# removed. The file is closed as soon as it has been read. Blank lines (e.g. a
# trailing newline) are dropped; otherwise they would yield an empty subject
# name and the loop below would try to process ".../subjects/mri".
list_file = os.path.join(mb_info_path, 'mindboggle101_list.txt')
with open(list_file, 'r') as fid:
    subjects = [''.join(line.split()) for line in fid]
subjects = [subject for subject in subjects if subject]


def keep_volume_labels(input_file, labels_to_keep, output_file='',
                       second_file=''):
    """
    Keep only given labels in an image volume (or use to mask second volume).

    Every voxel of ``input_file`` whose label is not in ``labels_to_keep`` is
    set to zero. If ``second_file`` is given, the same voxels are zeroed in
    the second volume instead, and ``input_file`` only serves as the mask.

    Parameters
    ----------
    input_file : string
        labeled nibabel-readable (e.g., nifti) file
    labels_to_keep : list of integers
        labels to keep (any iterable of numbers is accepted)
    output_file : string
        output file name; if empty, the basename of the volume being written
        (``second_file`` if given, else ``input_file``) in the current
        working directory
    second_file : string
        second nibabel-readable file (keep/erase voxels in this file instead);
        must contain the same number of voxels as ``input_file``

    Returns
    -------
    output_file : string
        output file name

    Raises
    ------
    ValueError
        if ``second_file`` and ``input_file`` differ in number of voxels
    IOError
        if the output file was not created

    Examples
    --------
    >>> # Remove right hemisphere labels
    >>> import os
    >>> from mindboggle.guts.relabel import keep_volume_labels
    >>> from mindboggle.mio.labels import DKTprotocol
    >>> from mindboggle.mio.fetch_data import prep_tests
    >>> urls, fetch_data = prep_tests()
    >>> input_file = fetch_data(urls['freesurfer_labels'], '', '.nii.gz')
    >>> second_file = ''
    >>> labels_to_keep = list(range(1000, 1036))
    >>> output_file = 'keep_volume_labels.nii.gz'
    >>> output_file = keep_volume_labels(input_file, labels_to_keep,
    ...                                  output_file, second_file)

    View nifti file (skip test):

    >>> from mindboggle.mio.plots import plot_volumes
    >>> plot_volumes(output_file) # doctest: +SKIP

    """
    import os
    import numpy as np
    import nibabel as nb

    # ------------------------------------------------------------------------
    # Load labeled image volume and extract data as 1-D array.
    # ``img.affine`` / ``np.asarray(img.dataobj)`` replace the removed
    # ``get_affine()`` / ``get_data()`` and keep the on-disk dtype.
    # ------------------------------------------------------------------------
    vol = nb.load(input_file)
    xfm = vol.affine
    data = np.asarray(vol.dataobj).ravel()

    # ------------------------------------------------------------------------
    # If second file specified, erase voxels whose corresponding
    # voxels in the input_file have labels not in labels_to_keep:
    # ------------------------------------------------------------------------
    if second_file:
        # Load second image volume as a 1-D array. flatten() always copies,
        # so the in-place masking below never touches a memory-mapped file.
        vol = nb.load(second_file)
        xfm = vol.affine
        new_data = np.asarray(vol.dataobj).flatten()
        # Voxels are matched by flat index, so the counts must agree.
        if new_data.size != data.size:
            raise ValueError(
                "second_file {0} has {1} voxels but input_file {2} has "
                "{3}.".format(second_file, new_data.size, input_file,
                              data.size))
        if not output_file:
            output_file = os.path.join(os.getcwd(),
                                       os.path.basename(second_file))
    # ------------------------------------------------------------------------
    # If second file not specified, remove labels not in labels_to_keep:
    # ------------------------------------------------------------------------
    else:
        new_data = data.copy()
        if not output_file:
            output_file = os.path.join(os.getcwd(),
                                       os.path.basename(input_file))

    # ------------------------------------------------------------------------
    # Erase voxels as specified above, in one vectorized pass.
    # list() first: np.isin treats a set as a single object, not a
    # collection of labels.
    # ------------------------------------------------------------------------
    keep = np.isin(data, np.asarray(list(labels_to_keep)))
    new_data[~keep] = 0

    # ------------------------------------------------------------------------
    # Reshape to original dimensions (of the volume being written):
    # ------------------------------------------------------------------------
    new_data = np.reshape(new_data, vol.shape)

    # ------------------------------------------------------------------------
    # Save relabeled file:
    # ------------------------------------------------------------------------
    img = nb.Nifti1Image(new_data, xfm)
    img.to_filename(output_file)

    if not os.path.exists(output_file):
        raise IOError("keep_volume_labels() did not create " + output_file + ".")

    return output_file


# DKT31 cortical region numbers. In the label volumes they appear as
# 1000 + n and 2000 + n (one range per hemisphere). They are the same for
# every subject, so the label list is computed once, outside the loop.
DKT31_numbers = [2, 3] + list(range(5, 32)) + [34, 35]
labels_to_keep = [1000 + x for x in DKT31_numbers]
labels_to_keep.extend([2000 + x for x in DKT31_numbers])


def _run_command(args):
    """Print a command line, then run it directly (no shell).

    Passing an argument list keeps paths containing spaces intact.
    check=True raises subprocess.CalledProcessError on a nonzero exit status,
    so a failed step stops the pipeline. Otherwise later steps would run on
    missing or stale files.
    """
    print(' '.join(args))
    subprocess.run(args, check=True)


for subject in subjects:

    print(">>> Process subject: {0}...".format(subject))
    subject_path = os.path.join(mb101_path, 'subjects', subject, 'mri')

    # Identify original files
    full_labels_orig = os.path.join(subject_path, 'aparcNMMjt+aseg.nii.gz')
    head = os.path.join(subject_path, 't1weighted.nii.gz')
    brain = os.path.join(subject_path, 't1weighted_brain.nii.gz')

    # Name all output files
    full_labels = os.path.join(subject_path, 'labels.DKT31.manual+aseg.nii.gz')
    DKT31_labels = os.path.join(subject_path, 'labels.DKT31.manual.nii.gz')
    xfm_matrix = os.path.join(subject_path, 't1weighted_brain.MNI152.affine.txt')
    xfm_brain = os.path.join(subject_path, 't1weighted_brain.MNI152.nii.gz')
    xfm_head = os.path.join(subject_path, 't1weighted.MNI152.nii.gz')
    xfm_DKT31 = os.path.join(subject_path, 'labels.DKT31.manual.MNI152.nii.gz')
    xfm_DKT31aseg = os.path.join(subject_path, 'labels.DKT31.manual+aseg.MNI152.nii.gz')

    # Remove old labels and affine-transformed files (regular files only)
    for rm_file in os.listdir(subject_path):
        rm_path = os.path.join(subject_path, rm_file)
        if ('labels.' in rm_file or '.MNI152.' in rm_file) and \
                os.path.isfile(rm_path):
            os.remove(rm_path)

    # Convert label volume from FreeSurfer to original space
    print("Convert label volume from FreeSurfer to original space...")
    _run_command(['mri_vol2vol', '--nearest', '--mov', full_labels_orig,
                  '--targ', head, '--regheader', '--o', full_labels])

    # Affine register T1-weighted brain to MNI152 brain using FSL's flirt
    print("Affine register T1-weighted brain to MNI152 brain using FSL's flirt...")
    _run_command(['flirt', '-in', brain, '-ref', template,
                  '-out', xfm_brain, '-omat', xfm_matrix])

    # Transfer whole-head images with affine transform using FSL's flirt
    print("Apply affine transform to whole-head using FSL's flirt...")
    _run_command(['flirt', '-in', head, '-ref', template,
                  '-applyxfm', '-init', xfm_matrix, '-out', xfm_head])

    # Transfer DKT31- plus FreeSurfer-aseg-labeled images with affine
    # transform (nearest-neighbor interpolation keeps label values intact)
    print("Apply affine transform to labeled images (with nearest neighbor interpolation)...")
    _run_command(['flirt', '-in', full_labels, '-ref', template,
                  '-applyxfm', '-init', xfm_matrix,
                  '-interp', 'nearestneighbour', '-out', xfm_DKT31aseg])

    # Remove all but DKT31 (cortical) labels
    print("Remove non-DKT31 (cortical) labels...")
    keep_volume_labels(full_labels, labels_to_keep, output_file=DKT31_labels,
                       second_file='')

    # Transfer DKT31-labeled images with affine transform (nearest-neighbor interpolation)
    _run_command(['flirt', '-in', DKT31_labels, '-ref', template,
                  '-applyxfm', '-init', xfm_matrix,
                  '-interp', 'nearestneighbour', '-out', xfm_DKT31])

    # Compress subject directory
    #subject_path2 = os.path.join(mb101_path, 'subjects', subject)
    #cmd =  ' '.join(['tar cvfz', subject_path2+'.tar.gz', subject_path2])
    #print(cmd); os.system(cmd)

